from __future__ import annotations

from datetime import date, datetime
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest

import polars as pl
import polars.selectors as cs
from polars.exceptions import ColumnNotFoundError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import PolarsDataType


def test_unique_predicate_pd() -> None:
    """Check how a filter placed after ``unique`` behaves for several subsets."""
    lf = pl.LazyFrame(
        {
            "x": ["abc", "abc"],
            "y": ["xxx", "xxx"],
            "z": [True, False],
        }
    )

    # Different spellings of the subset; every one of them treats the two
    # input rows as duplicates of each other.
    subsets: tuple[Any, ...] = (
        ["x", "y"],
        pl.exclude("z"),
        pl.col("x", "y"),
        ~cs.starts_with("z"),
        [pl.col("x"), pl.col("y").str.len_bytes()],
    )
    for subset in subsets:
        deduplicated = lf.unique(subset=subset, maintain_order=True, keep="last")
        result = deduplicated.filter(pl.col("z")).collect()
        # keep="last" retains the z=False row, so filtering on z leaves nothing.
        expected = pl.DataFrame(
            schema={"x": pl.String, "y": pl.String, "z": pl.Boolean}
        )
        assert_frame_equal(result, expected)

    # With keep="any" the z=True row is the expected survivor of the filter.
    any_kept = lf.unique(subset=["x", "y"], maintain_order=True, keep="any")
    result = any_kept.filter(pl.col("z")).collect()
    expected = pl.DataFrame({"x": ["abc"], "y": ["xxx"], "z": [True]})
    assert_frame_equal(result, expected)

    # Issue #14595: filter should not naively be pushed past unique()
    for maintain_order in (True, False):
        for keep in ("first", "last", "any", "none"):
            deduped = lf.unique("x", maintain_order=maintain_order, keep=keep)
            q = deduped.filter(pl.col("x") == "abc").filter(pl.col("z"))
            # The optimized plan must agree with the plan run without pushdown.
            without_pushdown = q.collect(
                optimizations=pl.QueryOptFlags(predicate_pushdown=False)
            )
            with_pushdown = q.collect()
            assert_frame_equal(without_pushdown, with_pushdown)


def test_unique_on_list_df() -> None:
    """Deduplicate a frame holding a nested-list column, preserving row order."""
    df = pl.DataFrame(
        {"a": [1, 2, 3, 4, 4], "b": [[1, 1], [2], [3], [4, 4], [4, 4]]}
    )
    result = df.unique(maintain_order=True)
    # The trailing duplicate (4, [4, 4]) row is the only one that should vanish.
    expected = {
        "a": [1, 2, 3, 4],
        "b": [[1, 1], [2], [3], [4, 4]],
    }
    assert result.to_dict(as_series=False) == expected


def test_list_unique() -> None:
    """``unique``, ``arg_unique`` and ``n_unique`` on a Series of nested lists."""
    values = [[1, 2], [3], [1, 2], [4, 5], [2], [2]]
    s = pl.Series("a", values)

    distinct = s.unique(maintain_order=True)
    assert distinct.to_list() == [[1, 2], [3], [4, 5], [2]]

    # Indices of the first occurrence of each distinct value.
    first_seen = s.arg_unique()
    assert first_seen.to_list() == [0, 1, 3, 4]

    assert s.n_unique() == 4


def test_array_unique() -> None:
    """``unique``, ``arg_unique`` and ``n_unique`` on a Series of numpy arrays."""
    pairs = ([1, 2], [3, 1], [1, 2], [4, 5], [2, 2], [2, 2])
    s = pl.Series("a", [np.array(pair) for pair in pairs])

    distinct = s.unique(maintain_order=True)
    assert distinct.to_list() == [[1, 2], [3, 1], [4, 5], [2, 2]]

    # Indices of the first occurrence of each distinct value.
    first_seen = s.arg_unique()
    assert first_seen.to_list() == [0, 1, 3, 4]

    assert s.n_unique() == 4


def test_unique_and_drop_stability() -> None:
    # see: 2898
    # the original cause was that we wrote:
    # expr_a = a.unique()
    # expr_a.filter(a.unique().is_not_null())
    # meaning that the a.unique was executed twice, which is an unstable algorithm
    df = pl.DataFrame({"a": [1, None, 1, None]})
    non_null_unique = pl.col("a").unique().drop_nulls()
    out = df.select(non_null_unique).to_series()
    assert out[0] == 1


def test_unique_empty() -> None:
    """``unique`` of an empty Series returns an equal empty Series."""
    dtypes = [pl.String, pl.Boolean, pl.Int32, pl.UInt32]
    for dtype in dtypes:
        empty = pl.Series([], dtype=dtype)
        result = empty.unique()
        assert_series_equal(result, empty)


def test_unique() -> None:
    """Basic ``unique`` checks for a LazyFrame (all columns, subset) and a Series."""
    ldf = pl.LazyFrame({"a": [1, 2, 2], "b": [3, 3, 3]})

    # All columns: the repeated (2, 3) row collapses into one.
    all_columns = ldf.unique(maintain_order=True).collect()
    assert_frame_equal(all_columns, pl.DataFrame({"a": [1, 2], "b": [3, 3]}))

    # Subset on the constant column "b": a single row is expected.
    by_b = ldf.unique(subset="b", maintain_order=True).collect()
    assert_frame_equal(by_b, pl.DataFrame({"a": [1], "b": [3]}))

    series = pl.Series("a", [1, 2, None, 2])
    distinct = pl.Series("a", [1, 2, None])
    assert_series_equal(series.unique(maintain_order=True), distinct)
    # Without maintain_order the output order is not compared.
    assert_series_equal(
        series.unique(maintain_order=False), distinct, check_order=False
    )


def test_struct_unique_df() -> None:
    """``unique`` on a frame with a struct column removes the duplicated row."""
    df = pl.DataFrame(
        {
            "numerical": [1, 2, 1],
            "struct": [{"x": 1, "y": 2}, {"x": 3, "y": 4}, {"x": 1, "y": 2}],
        }
    )
    # Sort afterwards so the comparison does not depend on unique()'s row order.
    result = df.select("numerical", "struct").unique().sort("numerical")

    # Rows 0 and 2 are identical in both columns, so two distinct rows remain.
    expected = pl.DataFrame(
        {
            "numerical": [1, 2],
            "struct": [{"x": 1, "y": 2}, {"x": 3, "y": 4}],
        }
    )
    assert_frame_equal(result, expected)


def test_sorted_unique_dates() -> None:
    """``unique`` on a frame sorted by a Date column keeps both distinct dates."""
    dates = pl.Series("dt", [date(2015, 6, 24), date(2015, 6, 23)], dtype=pl.Date)
    sorted_df = pl.DataFrame([dates]).sort("dt")
    out = sorted_df.unique(maintain_order=False)

    expected = pl.DataFrame({"dt": [date(2015, 6, 23), date(2015, 6, 24)]})
    # Row order is not compared because maintain_order=False was requested.
    assert_frame_equal(out, expected, check_row_order=False)


@pytest.mark.parametrize("maintain_order", [True, False])
def test_unique_null(maintain_order: bool) -> None:
    """``unique`` on Series built from no values or only ``None`` values."""
    empty = pl.Series([])
    single_null = pl.Series([None])
    double_null = pl.Series([None, None])

    # (input, expected) pairs; repeated nulls collapse to a single null.
    cases = [
        (empty, empty),
        (single_null, single_null),
        (double_null, single_null),
    ]
    for series, expected in cases:
        assert_series_equal(series.unique(maintain_order=maintain_order), expected)


@pytest.mark.parametrize(
    ("input", "output"),
    [
        ([], []),
        (["a", "b", "b", "c"], ["a", "b", "c"]),
        ([None, "a", "b", "b"], [None, "a", "b"]),
    ],
)
def test_unique_categorical(input: list[str | None], output: list[str | None]) -> None:
    """``unique`` on Categorical Series, with and without order maintenance."""
    s = pl.Series(input, dtype=pl.Categorical)
    expected = pl.Series(output, dtype=pl.Categorical)

    ordered = s.unique(maintain_order=True)
    assert_series_equal(ordered, expected)

    # Without maintain_order only the set of values is compared.
    unordered = s.unique(maintain_order=False)
    assert_series_equal(unordered, expected, check_order=False)


def test_unique_with_null() -> None:
    """``unique`` on a frame that contains an all-null column."""
    df = pl.DataFrame(
        {
            "a": [1, 1, 2, 2, 3, 4],
            "b": ["a", "a", "b", "b", "c", "c"],
            "c": [None] * 6,
        }
    )
    result = df.unique(maintain_order=True)

    expected_df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4],
            "b": ["a", "b", "c", "c"],
            "c": [None] * 4,
        }
    )
    assert_frame_equal(result, expected_df)


@pytest.mark.parametrize(
    ("input_json_data", "input_schema", "subset"),
    [
        (
            {"ID": [], "Name": []},
            {"ID": pl.Int64, "Name": pl.String},
            "id",
        ),
        (
            {"ID": [], "Name": []},
            {"ID": pl.Int64, "Name": pl.String},
            ["age", "place"],
        ),
        (
            {"ID": [1, 2, 1, 2], "Name": ["foo", "bar", "baz", "baa"]},
            {"ID": pl.Int64, "Name": pl.String},
            "id",
        ),
        (
            {"ID": [1, 2, 1, 2], "Name": ["foo", "bar", "baz", "baa"]},
            {"ID": pl.Int64, "Name": pl.String},
            ["age", "place"],
        ),
    ],
)
def test_unique_with_bad_subset(
    input_json_data: dict[str, list[Any]],
    input_schema: dict[str, PolarsDataType],
    subset: str | list[str],
) -> None:
    """``unique`` must raise when ``subset`` names columns missing from the frame.

    Each case uses a subset ("id" in lower case, or "age"/"place") that does
    not match the "ID"/"Name" columns, for both empty and populated frames.
    """
    frame = pl.DataFrame(input_json_data, schema=input_schema)

    expected_error = pytest.raises(ColumnNotFoundError, match="not found")
    with expected_error:
        frame.unique(subset=subset)


def test_categorical_unique_19409() -> None:
    """``unique`` on a Categorical column with 50 distinct labels over 127 rows."""
    labels = [str(n % 50) for n in range(127)]
    df = pl.DataFrame({"x": labels}).cast(pl.Categorical)

    uniq = df.unique()

    assert uniq.height == 50
    # No null may be introduced, and the set of values must be unchanged.
    assert uniq.null_count().item() == 0
    assert set(uniq["x"]) == set(df["x"])


def test_categorical_updated_revmap_unique_20233() -> None:
    """Check ``unique`` on a Categorical Series built from chained literals."""
    # NOTE(review): this Series is rebound below without ever being read. It
    # presumably exists to register "A" as a category before the literals are
    # created (the test name mentions an updated revmap) — confirm before
    # removing it as dead code.
    s = pl.Series("a", ["A"], pl.Categorical)

    # The second select overwrites column "a", so only the "D" literal remains.
    s = (
        pl.select(a=pl.when(True).then(pl.lit("C", pl.Categorical)))
        .select(a=pl.when(True).then(pl.lit("D", pl.Categorical)))
        .to_series()
    )

    assert_series_equal(s.unique(), pl.Series("a", ["D"], pl.Categorical))


@pytest.mark.parametrize(
    "subset",
    [
        "key",
        ["key"],
        pl.col("key"),
        pl.col("key").str.extract(r"^([a-z]+)_"),
        pl.exclude("value", "number"),
        cs.exclude("number", "value"),
    ],
)
def test_unique_check_order_20480(subset: Any) -> None:
    """``unique(keep="first")`` after a sort must keep the first sorted row."""
    rows = [
        {"key": "some_key", "value": "second", "number": 2},
        {"key": "some_key", "value": "first", "number": 1},
    ]
    lf = pl.DataFrame(rows).lazy()

    # After sorting by ("key", "number") the number=1 row comes first, so it is
    # the one keep="first" has to retain.
    result = lf.sort("key", "number").unique(subset=subset, keep="first").collect()

    assert result["number"].item() == 1


def test_predicate_pushdown_unique() -> None:
    """Check which predicates end up as the top-level FILTER above ``unique``.

    The query plan text is inspected: a plan that starts with "FILTER" still
    has the filter as its outermost node.
    """
    q = (
        pl.LazyFrame({"id": [1, 2, 3]})
        .with_columns(pl.date(2024, 1, 1) + pl.duration(days=pl.Series([1, 2, 3])))  # type: ignore[arg-type]
        .unique()
    )

    # An `is_in` predicate on "id" is expected to move below the top of the
    # plan. The plan is passed as the assertion message so it is shown on
    # failure without printing on every run.
    is_in_plan = q.filter(pl.col("id").is_in([1, 2, 3])).explain()
    assert not is_in_plan.startswith("FILTER"), is_in_plan

    # A predicate containing an aggregation (`sum`) is expected to stay on top.
    agg_plan = q.filter(pl.col("id").sum() == pl.col("id")).explain()
    assert agg_plan.startswith("FILTER"), agg_plan


def test_unique_enum_19338() -> None:
    """``unique`` on an Enum column for column-wise and row-wise input data."""
    inputs = [
        {"enum": ["A"]},
        [{"enum": "A"}],
    ]
    for data in inputs:
        df = pl.DataFrame(data, schema={"enum": pl.Enum(["A", "B", "C"])})
        result = df.select(pl.col("enum").unique())

        # The result must keep the full Enum dtype and hold only the "A" value.
        expected_schema = {"enum": pl.Enum(["A", "B", "C"])}
        expected = pl.DataFrame({"enum": ["A"]}, schema=expected_schema)
        assert_frame_equal(result, expected)


def test_unique_nan_12950() -> None:
    """``unique`` on a single-row frame holding NaN returns an equal frame."""
    nan_frame = pl.DataFrame({"x": float("nan")})
    assert_frame_equal(nan_frame.unique(), nan_frame)


def test_unique_lengths_21654() -> None:
    """``unique`` keeps every row of an all-distinct integer range of any length."""
    for height in range(0, 1000, 37):
        values = pl.int_range(height, eager=True)
        frame = pl.DataFrame({"x": values})
        deduplicated = frame.unique()
        assert deduplicated.height == height


def test_unique_keep_last_with_slice_22470() -> None:
    """``unique(keep="last")`` followed by ``slice`` returns the expected window."""
    values = [0, 1, 2, 3, 4, 5, 6, 7, 3, 4, 5, 6, 7, 8, 9, 10]
    lf = pl.LazyFrame({"x": values})

    deduplicated = lf.unique(keep="last", maintain_order=True)
    result = deduplicated.slice(3, 4).collect()

    assert_frame_equal(result, pl.DataFrame({"x": [3, 4, 5, 6]}))


def test_unique_booleans_22753() -> None:
    """``unique`` on Boolean slices that contain only nulls yields one null."""
    # slice(3) skips the single True at index 2, leaving 128 nulls.
    sliced = pl.Series([None, None, True] + [None] * 128).slice(3)
    assert_series_equal(sliced.unique(), pl.Series([None], dtype=pl.Boolean()))

    # head(2) stops before the True value, leaving two nulls.
    head = pl.Series([None, None, True]).head(2)
    assert_series_equal(head.unique(), pl.Series([None], dtype=pl.Boolean()))


def test_unique_i128_24231() -> None:
    """``unique`` on distinct Int128 values of large magnitude keeps every row."""
    values = [-(1 << 127), -(1 << 126), 1 << 125, 1 << 126]
    df = pl.Series(values, dtype=pl.Int128).to_frame("a")
    deduplicated = df.unique()
    assert_frame_equal(df, deduplicated, check_row_order=False)


def test_unique_keep_none_24260() -> None:
    """Lazy ``unique(keep="none")`` on a subset with one duplicated value."""
    lf = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6]}).lazy()
    out = lf.unique(subset="b", keep="none").collect()

    # b == 4 occurs twice, so only the b == 6 row is expected to remain.
    expected = pl.DataFrame({"a": [2], "b": [6]})
    assert_frame_equal(out, expected)


def test_unique_column_subset_25233() -> None:
    """``unique`` on a string-column subset of a frame with a datetime column."""
    # Seven timestamps, 15 minutes apart, from 00:00 to 01:30.
    times = pl.datetime_range(
        start=datetime(2021, 12, 16),
        end=datetime(2021, 12, 16, 1, 30),
        interval="15m",
        eager=True,
    )
    df = pl.DataFrame(
        {
            "time": times,
            "op_type": ["x", "y", "x", "y", "x", "x", "y"],
            "value": [1, 2, 3, 4, 6, 7, 8],
        }
    )

    result = df.unique(subset="op_type")

    # One row per distinct op_type ("x" and "y").
    assert result.height == 2
    distinct_ops = result.select(pl.col.op_type.n_unique()).item()
    assert distinct_ops == 2


@pytest.mark.parametrize("stable", [False, True])
def test_unique_list_arr_non_numeric(stable: bool) -> None:
    """``unique`` on nested string data, with an inferred and an Array dtype."""
    values = [["A"], ["B"], ["A"]]
    distinct = [["A"], ["B"]]

    # dtype=None is the constructor default, i.e. the dtype is inferred.
    for dtype in (None, pl.Array(pl.String, 1)):
        result = pl.Series(values, dtype=dtype).unique(maintain_order=stable)
        expected = pl.Series(distinct, dtype=dtype)
        # Order is only compared when a stable unique was requested.
        assert_series_equal(result, expected, check_order=stable)


@pytest.mark.parametrize("maintain_order", [False, True])
@pytest.mark.parametrize("stable", [False, True])
def test_unique_on_literal_in_agg(maintain_order: bool, stable: bool) -> None:
    """``unique`` applied to a literal inside a ``group_by().agg()`` context."""
    unique_literal = pl.lit(1, pl.Int64).unique(maintain_order=stable)
    grouped = pl.DataFrame({"a": [0, 1]}).group_by("a", maintain_order=maintain_order)
    result = grouped.agg(b=unique_literal)

    # Each group is expected to receive a one-element list holding the literal.
    expected = pl.DataFrame(
        {"a": [0, 1], "b": [[1], [1]]},
        schema_overrides={"b": pl.List(pl.Int64)},
    )
    # Row order is only compared when the group_by maintained it.
    assert_frame_equal(result, expected, check_row_order=maintain_order)
